"""Slurm instance provisioning."""

import tempfile
import textwrap
import time
from typing import Any, cast, Dict, List, Optional, Tuple

from sky import sky_logging
from sky.adaptors import slurm
from sky.provision import common
from sky.provision import constants
from sky.provision.slurm import utils as slurm_utils
from sky.utils import command_runner
from sky.utils import common_utils
from sky.utils import status_lib
from sky.utils import subprocess_utils
from sky.utils import timeline

logger = sky_logging.init_logger(__name__)

# TODO(kevin): This assumes $HOME is in a shared filesystem.
# We should probably make it configurable, and add a check
# during sky check.
SHARED_ROOT_SKY_DIRECTORY = '~/.sky_clusters'
PROVISION_SCRIPTS_DIRECTORY_NAME = '.sky_provision'
PROVISION_SCRIPTS_DIRECTORY = f'~/{PROVISION_SCRIPTS_DIRECTORY_NAME}'

POLL_INTERVAL_SECONDS = 2
# Default KillWait is 30 seconds, so we add some buffer time here.
_TIMEOUT_SECONDS_FOR_JOB_TERMINATION = 60


def _sky_cluster_home_dir(cluster_name_on_cloud: str) -> str:
    """Home directory for a SkyPilot cluster on the Slurm shared filesystem.

    Each cluster gets its own subdirectory under the shared root.
    """
    return '/'.join([SHARED_ROOT_SKY_DIRECTORY, cluster_name_on_cloud])


def _sbatch_provision_script_path(filename: str) -> str:
    """Returns the path to the sbatch provision script on the login node.

    Args:
        filename: Basename of the sbatch script, e.g. '<cluster>.sh'.

    Returns:
        Path under the provision-scripts directory in $HOME.
    """
    # Put sbatch script in $HOME instead of /tmp as there can be
    # multiple login nodes, and different SSH connections
    # can land on different login nodes.
    # Fix: the original returned a literal '(unknown)' component and never
    # used `filename`, so every cluster's script would collide on one path.
    return f'{PROVISION_SCRIPTS_DIRECTORY}/{filename}'


def _skypilot_runtime_dir(cluster_name_on_cloud: str) -> str:
    """Returns the SkyPilot runtime directory path on the Slurm cluster."""
    return f'/tmp/{cluster_name_on_cloud}'


@timeline.event
def _create_virtual_instance(
        region: str, cluster_name_on_cloud: str,
        config: common.ProvisionConfig) -> common.ProvisionRecord:
    """Creates a Slurm virtual instance from the config.

    A Slurm virtual instance is created by submitting a long-running
    job with sbatch, to mimic a cloud VM.

    Flow:
      1. Wait for any jobs with this cluster's name stuck in COMPLETING
         state (raises RuntimeError on timeout).
      2. If a pending/running job with this name already exists, reuse it.
      3. Otherwise build an sbatch script, rsync it to the login node,
         submit it, and wait for node allocation.

    Args:
        region: Region name, recorded in the returned ProvisionRecord.
        cluster_name_on_cloud: Used as the Slurm job name, so jobs can be
            looked up by cluster name later.
        config: Provision config. `provider_config['ssh']` must hold the
            login node's connection info; `node_config` holds the requested
            resources (cpus, memory, accelerators).

    Returns:
        A ProvisionRecord for the new or pre-existing virtual instance.

    Raises:
        RuntimeError: If jobs with this name remain in COMPLETING state
            after _TIMEOUT_SECONDS_FOR_JOB_TERMINATION seconds.
    """
    provider_config = config.provider_config
    ssh_config_dict = provider_config['ssh']
    ssh_host = ssh_config_dict['hostname']
    ssh_port = int(ssh_config_dict['port'])
    ssh_user = ssh_config_dict['user']
    ssh_key = ssh_config_dict['private_key']
    ssh_proxy_command = ssh_config_dict.get('proxycommand', None)
    partition = slurm_utils.get_partition_from_config(provider_config)

    client = slurm.SlurmClient(
        ssh_host,
        ssh_port,
        ssh_user,
        ssh_key,
        ssh_proxy_command=ssh_proxy_command,
    )

    # COMPLETING state occurs when a job is being terminated - during this
    # phase, slurmd sends SIGTERM to tasks, waits for KillWait period, sends
    # SIGKILL if needed, runs epilog scripts, and notifies slurmctld. This
    # typically happens when a previous job with the same name is being
    # cancelled or has finished. Jobs can get stuck in COMPLETING if epilog
    # scripts hang or tasks don't respond to signals, so we wait with a
    # timeout.
    completing_jobs = client.query_jobs(
        cluster_name_on_cloud,
        ['completing'],
    )
    start_time = time.time()
    while (completing_jobs and
           time.time() - start_time < _TIMEOUT_SECONDS_FOR_JOB_TERMINATION):
        logger.debug(f'Found {len(completing_jobs)} completing jobs. '
                     f'Waiting for them to finish: {completing_jobs}')
        time.sleep(POLL_INTERVAL_SECONDS)
        completing_jobs = client.query_jobs(
            cluster_name_on_cloud,
            ['completing'],
        )
    if completing_jobs:
        # TODO(kevin): Automatically handle this, following the suggestions in
        # https://slurm.schedmd.com/troubleshoot.html#completing
        raise RuntimeError(f'Found {len(completing_jobs)} jobs still in '
                           'completing state after '
                           f'{_TIMEOUT_SECONDS_FOR_JOB_TERMINATION}s. '
                           'This is typically due to non-killable processes '
                           'associated with the job.')

    # Check if job already exists
    existing_jobs = client.query_jobs(
        cluster_name_on_cloud,
        ['pending', 'running'],
    )
    if existing_jobs:
        # Idempotent re-provision: reuse the live job as the virtual
        # instance instead of submitting a duplicate.
        assert len(existing_jobs) == 1, (
            f'Multiple jobs found with name {cluster_name_on_cloud}: '
            f'{existing_jobs}')

        job_id = existing_jobs[0]
        logger.debug(f'Job with name {cluster_name_on_cloud} already exists '
                     f'(JOBID: {job_id})')

        # Wait for nodes to be allocated (job might be in PENDING state)
        nodes, _ = client.get_job_nodes(job_id, wait=True)
        return common.ProvisionRecord(provider_name='slurm',
                                      region=region,
                                      zone=partition,
                                      cluster_name=cluster_name_on_cloud,
                                      head_instance_id=slurm_utils.instance_id(
                                          job_id, nodes[0]),
                                      resumed_instance_ids=[],
                                      created_instance_ids=[])

    # No existing job: create a fresh virtual instance.
    resources = config.node_config
    num_nodes = config.count
    # TODO(kevin): Support multi-node.
    assert num_nodes == 1

    # accelerator_count may be missing or non-numeric; treat any such
    # value as "no GPUs requested" rather than failing provisioning.
    accelerator_type = resources.get('accelerator_type')
    accelerator_count_raw = resources.get('accelerator_count')
    try:
        accelerator_count = int(
            accelerator_count_raw) if accelerator_count_raw is not None else 0
    except (TypeError, ValueError):
        accelerator_count = 0

    skypilot_runtime_dir = _skypilot_runtime_dir(cluster_name_on_cloud)
    sky_dir = _sky_cluster_home_dir(cluster_name_on_cloud)

    # Build the sbatch script
    # Only emit a --gres directive when a concrete GPU type and a positive
    # count were requested, e.g. '#SBATCH --gres=gpu:v100:4'.
    gpu_directive = ''
    if (accelerator_type is not None and accelerator_type.upper() != 'NONE' and
            accelerator_count > 0):
        gpu_directive = (f'#SBATCH --gres=gpu:{accelerator_type.lower()}:'
                         f'{accelerator_count}')

    # By default stdout and stderr will be written to $HOME/slurm-%j.out
    # (because we invoke sbatch from $HOME). Redirect elsewhere to not pollute
    # the home directory.
    # The job body is just 'sleep infinity': the job exists to hold the node
    # allocation; the TERM trap tears down the Skylet and sky directories
    # when the job is cancelled.
    provision_script = textwrap.dedent(f"""\
        #!/bin/bash
        #SBATCH --job-name={cluster_name_on_cloud}
        #SBATCH --output={PROVISION_SCRIPTS_DIRECTORY_NAME}/slurm-%j.out
        #SBATCH --error={PROVISION_SCRIPTS_DIRECTORY_NAME}/slurm-%j.out
        #SBATCH --nodes={num_nodes}
        #SBATCH --wait-all-nodes=1
        #SBATCH --cpus-per-task={int(resources["cpus"])}
        #SBATCH --mem={int(resources["memory"])}G
        {gpu_directive}

        # Cleanup function to remove cluster dirs on job termination.
        cleanup() {{
            # The Skylet is daemonized, so it is not automatically terminated when
            # the Slurm job is terminated, we need to kill it manually.
            echo "Terminating Skylet..."
            if [ -f "{skypilot_runtime_dir}/.sky/skylet_pid" ]; then
                kill $(cat "{skypilot_runtime_dir}/.sky/skylet_pid") 2>/dev/null || true
            fi
            echo "Cleaning up sky directories..."
            rm -rf {skypilot_runtime_dir}
            rm -rf {sky_dir}
        }}
        trap cleanup TERM

        # Create sky directory for the cluster.
        # TODO(kevin): Since this is run inside the sbatch script, failures
        # will not be surfaced in a synchronous way. We should add a check
        # to verify the creation of the directory.
        mkdir -p {sky_dir} {skypilot_runtime_dir}
        # Suppress login messages.
        touch {sky_dir}/.hushlogin
        sleep infinity
        """)

    # To bootstrap things, we need to do it with SSHCommandRunner first.
    # SlurmCommandRunner is for after the virtual instances are created.
    login_node_runner = command_runner.SSHCommandRunner(
        (ssh_host, ssh_port),
        ssh_user,
        ssh_key,
        ssh_proxy_command=ssh_proxy_command,
    )

    cmd = f'mkdir -p {PROVISION_SCRIPTS_DIRECTORY}'
    rc, stdout, stderr = login_node_runner.run(cmd,
                                               require_outputs=True,
                                               stream_logs=False)
    subprocess_utils.handle_returncode(
        rc,
        cmd,
        'Failed to create provision scripts directory on login node.',
        stderr=f'{stdout}\n{stderr}')
    # Rsync the provision script to the login node
    # (via a local temp file; delete=True removes it once the copy is done).
    with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=True) as f:
        f.write(provision_script)
        f.flush()
        src_path = f.name
        tgt_path = _sbatch_provision_script_path(f'{cluster_name_on_cloud}.sh')
        login_node_runner.rsync(src_path, tgt_path, up=True, stream_logs=False)

    # tgt_path (the remote script path) outlives the `with` block above.
    job_id = client.submit_job(partition, cluster_name_on_cloud, tgt_path)
    logger.debug(f'Successfully submitted Slurm job {job_id} for cluster '
                 f'{cluster_name_on_cloud} with {num_nodes} nodes')

    # Block until the job's nodes are allocated, then derive instance ids.
    nodes, _ = client.get_job_nodes(job_id, wait=True)
    created_instance_ids = [
        slurm_utils.instance_id(job_id, node) for node in nodes
    ]

    return common.ProvisionRecord(provider_name='slurm',
                                  region=region,
                                  zone=partition,
                                  cluster_name=cluster_name_on_cloud,
                                  head_instance_id=created_instance_ids[0],
                                  resumed_instance_ids=[],
                                  created_instance_ids=created_instance_ids)


@common_utils.retry
def query_instances(
    cluster_name: str,
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    non_terminated_only: bool = True,
    retry_if_missing: bool = False,
) -> Dict[str, Tuple[Optional[status_lib.ClusterStatus], Optional[str]]]:
    """See sky/provision/__init__.py"""
    del cluster_name, retry_if_missing  # Unused for Slurm
    assert provider_config is not None, (cluster_name_on_cloud, provider_config)

    ssh_cfg = provider_config['ssh']
    client = slurm.SlurmClient(
        ssh_cfg['hostname'],
        int(ssh_cfg['port']),
        ssh_cfg['user'],
        ssh_cfg['private_key'],
        ssh_proxy_command=ssh_cfg.get('proxycommand', None),
    )

    # Slurm job state -> SkyPilot ClusterStatus (None means terminated).
    # Slurm states:
    # https://slurm.schedmd.com/squeue.html#SECTION_JOB-STATE-CODES
    state_to_status: Dict[str, Optional[status_lib.ClusterStatus]] = {
        'pending': status_lib.ClusterStatus.INIT,
        'running': status_lib.ClusterStatus.UP,
        'completing': status_lib.ClusterStatus.UP,
        'completed': None,
        'cancelled': None,
        'failed': status_lib.ClusterStatus.INIT,
    }

    statuses: Dict[str, Tuple[Optional[status_lib.ClusterStatus],
                              Optional[str]]] = {}
    for slurm_state, sky_status in state_to_status.items():
        # Skip terminated states when only live instances were requested.
        if sky_status is None and non_terminated_only:
            continue
        for job_id in client.query_jobs(cluster_name_on_cloud, [slurm_state]):
            statuses[job_id] = (sky_status, None)

    return statuses


def run_instances(
        region: str,
        cluster_name: str,  # pylint: disable=unused-argument
        cluster_name_on_cloud: str,
        config: common.ProvisionConfig) -> common.ProvisionRecord:
    """Run instances for the given cluster (Slurm in this case)."""
    # Provisioning on Slurm amounts to submitting the long-running job.
    record = _create_virtual_instance(region, cluster_name_on_cloud, config)
    return record


def wait_instances(region: str, cluster_name_on_cloud: str,
                   state: Optional[status_lib.ClusterStatus]) -> None:
    """See sky/provision/__init__.py"""
    # run_instances already blocks until the Slurm job has nodes allocated,
    # so there is nothing left to wait for here.
    del region, cluster_name_on_cloud, state  # Unused.


def get_cluster_info(
        region: str,
        cluster_name_on_cloud: str,
        provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo:
    """Build a ClusterInfo for the cluster's currently running Slurm job."""
    del region  # Unused.
    assert provider_config is not None, cluster_name_on_cloud

    # The SSH host is the remote machine running slurmctld daemon.
    # Cross-cluster operations are supported by interacting with
    # the current controller. For details, please refer to
    # https://slurm.schedmd.com/multi_cluster.html.
    ssh_cfg = provider_config['ssh']
    ssh_host = ssh_cfg['hostname']
    ssh_port = int(ssh_cfg['port'])
    ssh_user = ssh_cfg['user']
    client = slurm.SlurmClient(
        ssh_host,
        ssh_port,
        ssh_user,
        ssh_cfg['private_key'],
        ssh_proxy_command=ssh_cfg.get('proxycommand', None),
    )

    # Look up the single running job backing this cluster.
    running_jobs = client.query_jobs(cluster_name_on_cloud, ['running'])
    if not running_jobs:
        # Cluster may be pending or already terminated.
        return common.ClusterInfo(
            instances={},
            head_instance_id=None,
            ssh_user=ssh_user,
            provider_name='slurm',
            provider_config=provider_config,
        )
    assert len(running_jobs) == 1, (
        f'Multiple running jobs found for cluster {cluster_name_on_cloud}: '
        f'{running_jobs}')
    job_id = running_jobs[0]

    # A running job already has its nodes allocated, so no need to wait.
    nodes, node_ips = client.get_job_nodes(job_id, wait=False)

    instances: Dict[str, List[common.InstanceInfo]] = {}
    for node, node_ip in zip(nodes, node_ips):
        instance_id = slurm_utils.instance_id(job_id, node)
        instances[f'{instance_id}'] = [
            common.InstanceInfo(
                instance_id=instance_id,
                internal_ip=node_ip,
                # All nodes are reached through the login node's address.
                external_ip=ssh_host,
                ssh_port=ssh_port,
                tags={
                    constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud,
                    'job_id': job_id,
                    'node': node,
                },
            )
        ]

    return common.ClusterInfo(
        instances=instances,
        head_instance_id=slurm_utils.instance_id(job_id, nodes[0]),
        ssh_user=ssh_user,
        provider_name='slurm',
        provider_config=provider_config,
    )


def stop_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    worker_only: bool = False,
) -> None:
    """Stopping is unsupported; Slurm virtual instances are kept running."""
    raise NotImplementedError()


def terminate_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    worker_only: bool = False,
) -> None:
    """See sky/provision/__init__.py"""
    assert provider_config is not None, cluster_name_on_cloud

    if worker_only:
        logger.warning(
            'worker_only=True is not supported for Slurm, this is a no-op.')
        return

    ssh_cfg = provider_config['ssh']
    private_key: Optional[str] = ssh_cfg['private_key']
    # Check if we are running inside a Slurm job (only happens with autodown,
    # where the Skylet invokes terminate_instances on the remote cluster).
    # In that case we assume SSH between nodes has been set up via each
    # node's ssh config, so no explicit private key is used.
    # TODO(kevin): Validate this assumption. Another way would be to
    # mount the private key to the remote cluster, like we do with
    # other clouds' API keys.
    if slurm_utils.is_inside_slurm_job():
        logger.debug('Running inside a Slurm job, using machine\'s ssh config')
        private_key = None

    client = slurm.SlurmClient(
        ssh_cfg['hostname'],
        int(ssh_cfg['port']),
        ssh_cfg['user'],
        private_key,
        ssh_proxy_command=ssh_cfg.get('proxycommand', None),
    )
    # Cancel every job carrying this cluster's name. SIGTERM lets the sbatch
    # script's trap run its cleanup before the job exits.
    client.cancel_jobs_by_name(
        cluster_name_on_cloud,
        signal='TERM',
        full=True,
    )


def open_ports(
    cluster_name_on_cloud: str,
    ports: List[str],
    provider_config: Optional[Dict[str, Any]] = None,
) -> None:
    """See sky/provision/__init__.py"""
    # Nothing to do for Slurm; deliberately a no-op.
    del cluster_name_on_cloud, ports, provider_config  # Unused.


def cleanup_ports(
    cluster_name_on_cloud: str,
    ports: List[str],
    provider_config: Optional[Dict[str, Any]] = None,
) -> None:
    """See sky/provision/__init__.py"""
    # Nothing to do for Slurm; deliberately a no-op.
    del cluster_name_on_cloud, ports, provider_config  # Unused.


def get_command_runners(
    cluster_info: common.ClusterInfo,
    **credentials: Dict[str, Any],
) -> List[command_runner.SlurmCommandRunner]:
    """Get a command runner for the given cluster."""
    assert cluster_info.provider_config is not None, cluster_info

    if cluster_info.head_instance_id is None:
        # No running job for this cluster.
        return []

    head_instance = cluster_info.get_head_instance()
    assert head_instance is not None, 'Head instance not found'
    cluster_name_on_cloud = head_instance.tags.get(
        constants.TAG_SKYPILOT_CLUSTER_NAME, None)
    assert cluster_name_on_cloud is not None, cluster_info

    ssh_user = cast(str, credentials.pop('ssh_user'))
    ssh_private_key = cast(str, credentials.pop('ssh_private_key'))
    sky_dir = _sky_cluster_home_dir(cluster_name_on_cloud)
    runtime_dir = _skypilot_runtime_dir(cluster_name_on_cloud)

    # Note: For Slurm, the external IP for all instances is the same,
    # it is the login node's. The internal IP is the private IP of the node.
    runners: List[command_runner.SlurmCommandRunner] = []
    # There can only be one InstanceInfo per instance_id.
    for instance_infos in cluster_info.instances.values():
        instance = instance_infos[0]
        runners.append(
            command_runner.SlurmCommandRunner(
                (instance.external_ip or '', instance.ssh_port),
                ssh_user,
                ssh_private_key,
                sky_dir=sky_dir,
                skypilot_runtime_dir=runtime_dir,
                job_id=instance.tags['job_id'],
                slurm_node=instance.tags['node'],
                **credentials))
    return runners
